<!DOCTYPE html>

<html lang="en">
  <head>
    <meta charset="utf-8" />
    <meta name="viewport" content="width=device-width, initial-scale=1.0" /><meta name="generator" content="Docutils 0.19: https://docutils.sourceforge.io/" />

    <meta name="viewport" content="width=device-width, initial-scale=1, shrink-to-fit=no">
    <meta http-equiv="x-ua-compatible" content="ie=edge">
    
    <title>3.5.1. Multi-Tower Architecture &#8212; FunRec Recommender Systems 0.0.1 documentation</title>

    <link rel="stylesheet" href="../../_static/material-design-lite-1.3.0/material.blue-deep_orange.min.css" type="text/css" />
    <link rel="stylesheet" href="../../_static/sphinx_materialdesign_theme.css" type="text/css" />
    <link rel="stylesheet" href="../../_static/fontawesome/all.css" type="text/css" />
    <link rel="stylesheet" href="../../_static/fonts.css" type="text/css" />
    <link rel="stylesheet" type="text/css" href="../../_static/pygments.css" />
    <link rel="stylesheet" type="text/css" href="../../_static/basic.css" />
    <link rel="stylesheet" type="text/css" href="../../_static/d2l.css" />
    <script data-url_root="../../" id="documentation_options" src="../../_static/documentation_options.js"></script>
    <script src="../../_static/jquery.js"></script>
    <script src="../../_static/underscore.js"></script>
    <script src="../../_static/_sphinx_javascript_frameworks_compat.js"></script>
    <script src="../../_static/doctools.js"></script>
    <script src="../../_static/sphinx_highlight.js"></script>
    <script src="../../_static/d2l.js"></script>
    <script async="async" src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
    <link rel="index" title="Index" href="../../genindex.html" />
    <link rel="search" title="Search" href="../../search.html" />
    <link rel="next" title="3.5.2. Dynamic Weight Modeling" href="2.dynamic_weight.html" />
    <link rel="prev" title="3.5. Multi-Scenario Modeling" href="index.html" /> 
  </head>
<body>
    <div class="mdl-layout mdl-js-layout mdl-layout--fixed-header mdl-layout--fixed-drawer"><header class="mdl-layout__header mdl-layout__header--waterfall ">
    <div class="mdl-layout__header-row">
        
        <div class="mdl-layout-spacer"></div>
        <nav class="mdl-navigation">
        
<form class="form-inline pull-sm-right" action="../../search.html" method="get">
      <div class="mdl-textfield mdl-js-textfield mdl-textfield--expandable mdl-textfield--floating-label mdl-textfield--align-right">
        <label id="quick-search-icon" class="mdl-button mdl-js-button mdl-button--icon"  for="waterfall-exp">
          <i class="material-icons">search</i>
        </label>
        <div class="mdl-textfield__expandable-holder">
          <input class="mdl-textfield__input" type="text" name="q"  id="waterfall-exp" placeholder="Search" />
          <input type="hidden" name="check_keywords" value="yes" />
          <input type="hidden" name="area" value="default" />
        </div>
      </div>
      <div class="mdl-tooltip" data-mdl-for="quick-search-icon">
      Quick search
      </div>
</form>
        
<a id="button-show-source"
    class="mdl-button mdl-js-button mdl-button--icon"
    href="../../_sources/chapter_2_ranking/5.multi_scenario/1.multi_tower.rst.txt" rel="nofollow">
  <i class="material-icons">code</i>
</a>
<div class="mdl-tooltip" data-mdl-for="button-show-source">
Show Source
</div>
        </nav>
    </div>
    <div class="mdl-layout__header-row header-links">
      <div class="mdl-layout-spacer"></div>
      <nav class="mdl-navigation">
          
              <a  class="mdl-navigation__link" href="https://funrec-notebooks.s3.eu-west-3.amazonaws.com/fun-rec.zip">
                  <i class="fas fa-download"></i>
                  Jupyter Notebooks
              </a>
          
              <a  class="mdl-navigation__link" href="https://github.com/datawhalechina/fun-rec">
                  <i class="fab fa-github"></i>
                  GitHub
              </a>
      </nav>
    </div>
</header><header class="mdl-layout__drawer">
    
          <!-- Title -->
      <span class="mdl-layout-title">
          <a class="title" href="../../index.html">
              <span class="title-text">
                  FunRec Recommender Systems
              </span>
          </a>
      </span>
    
    
      <div class="globaltoc">
        <span class="mdl-layout-title toc">Table Of Contents</span>
        
        
            
            <nav class="mdl-navigation">
                <ul>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_preface/index.html">Preface</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_installation/index.html">Installation</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_notation/index.html">Notation</a></li>
</ul>
<ul class="current">
<li class="toctree-l1"><a class="reference internal" href="../../chapter_0_introduction/index.html">1. Overview of Recommender Systems</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_0_introduction/1.intro.html">1.1. What Is a Recommender System?</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_0_introduction/2.outline.html">1.2. Book Overview</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_1_retrieval/index.html">2. Retrieval Models</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_1_retrieval/1.cf/index.html">2.1. Collaborative Filtering</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/1.cf/1.itemcf.html">2.1.1. Item-Based Collaborative Filtering</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/1.cf/2.usercf.html">2.1.2. User-Based Collaborative Filtering</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/1.cf/3.mf.html">2.1.3. Matrix Factorization</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/1.cf/4.summary.html">2.1.4. Summary</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_1_retrieval/2.embedding/index.html">2.2. Embedding-Based Retrieval</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/2.embedding/1.i2i.html">2.2.1. I2I Retrieval</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/2.embedding/2.u2i.html">2.2.2. U2I Retrieval</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/2.embedding/3.summary.html">2.2.3. Summary</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_1_retrieval/3.sequence/index.html">2.3. Sequential Retrieval</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/3.sequence/1.user_interests.html">2.3.1. Deepening User Interest Representations</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/3.sequence/2.generateive_recall.html">2.3.2. Generative Retrieval Methods</a></li>
<li class="toctree-l3"><a class="reference internal" href="../../chapter_1_retrieval/3.sequence/3.summary.html">2.3.3. Summary</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1 current"><a class="reference internal" href="../index.html">3. Ranking Models</a><ul class="current">
<li class="toctree-l2"><a class="reference internal" href="../1.wide_and_deep.html">3.1. Memorization and Generalization</a></li>
<li class="toctree-l2"><a class="reference internal" href="../2.feature_crossing/index.html">3.2. Feature Crossing</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../2.feature_crossing/1.second_order.html">3.2.1. Second-Order Feature Crossing</a></li>
<li class="toctree-l3"><a class="reference internal" href="../2.feature_crossing/2.higher_order.html">3.2.2. Higher-Order Feature Crossing</a></li>
</ul>
</li>
<li class="toctree-l2"><a class="reference internal" href="../3.sequence.html">3.3. Sequence Modeling</a></li>
<li class="toctree-l2"><a class="reference internal" href="../4.multi_objective/index.html">3.4. Multi-Objective Modeling</a><ul>
<li class="toctree-l3"><a class="reference internal" href="../4.multi_objective/1.arch.html">3.4.1. Evolution of Basic Architectures</a></li>
<li class="toctree-l3"><a class="reference internal" href="../4.multi_objective/2.dependency_modeling.html">3.4.2. Task Dependency Modeling</a></li>
<li class="toctree-l3"><a class="reference internal" href="../4.multi_objective/3.multi_loss_optim.html">3.4.3. Multi-Objective Loss Fusion</a></li>
</ul>
</li>
<li class="toctree-l2 current"><a class="reference internal" href="index.html">3.5. Multi-Scenario Modeling</a><ul class="current">
<li class="toctree-l3 current"><a class="current reference internal" href="#">3.5.1. Multi-Tower Architecture</a></li>
<li class="toctree-l3"><a class="reference internal" href="2.dynamic_weight.html">3.5.2. Dynamic Weight Modeling</a></li>
</ul>
</li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_3_rerank/index.html">4. Re-ranking Models</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_3_rerank/1.greedy.html">4.1. Greedy Re-ranking</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_3_rerank/2.personalized.html">4.2. Personalized Re-ranking</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_3_rerank/3.summary.html">4.3. Chapter Summary</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_4_trends/index.html">5. Challenges and Hot Research Topics</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_4_trends/1.debias.html">5.1. Model Debiasing</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_4_trends/2.cold_start.html">5.2. The Cold-Start Problem</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_4_trends/3.generative.html">5.3. Generative Recommendation</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_4_trends/4.summary.html">5.4. Chapter Summary</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_5_projects/index.html">6. Project Practice</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_5_projects/1.understanding.html">6.1. Understanding the Competition Task</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_5_projects/2.baseline.html">6.2. Baseline</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_5_projects/3.analysis.html">6.3. Data Analysis</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_5_projects/4.recall.html">6.4. Multi-Channel Retrieval</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_5_projects/5.feature_engineering.html">6.5. Feature Engineering</a></li>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_5_projects/6.ranking.html">6.6. Ranking Model</a></li>
</ul>
</li>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_appendix/index.html">7. Appendix</a><ul>
<li class="toctree-l2"><a class="reference internal" href="../../chapter_appendix/word2vec.html">7.1. Word2vec</a></li>
</ul>
</li>
</ul>
<ul>
<li class="toctree-l1"><a class="reference internal" href="../../chapter_references/references.html">References</a></li>
</ul>

            </nav>
        
        </div>
    
</header>
        <main class="mdl-layout__content" tabIndex="0">

	<script type="text/javascript" src="../../_static/sphinx_materialdesign_theme.js "></script>

    <div class="document">
        <div class="page-content" role="main">
        
  <section id="multi-tower">
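<span id="id1"></span><h1><span class="section-number">3.5.1. </span>Multi-Tower Architecture<a class="headerlink" href="#multi-tower" title="Permalink to this heading">¶</a></h1>
<p>In multi-objective modeling, as MMoE demonstrates, expert networks (Experts) extract low-level feature representations shared across tasks, while gating networks (Gates) dynamically assign expert weights according to the needs of each task. This
"shared experts + task-specific gates"
architecture captures both commonality (the general features extracted by the shared experts) and specificity (the task-specific weights assigned by the gates).</p>
<p>Multi-scenario modeling resembles multi-task learning but has a different focus. Multi-task learning handles different tasks under the same scenario/distribution (for example, estimating CTR and CVR simultaneously for a single sample), whereas multi-scenario modeling handles the same task under different scenarios/distributions (for example, estimating the same CTR in different scenarios). The former predicts several different targets for one sample; the latter predicts the same target for samples from different scenarios. If each scenario gets its own independent model, the commonality across scenarios is ignored, small scenarios perform poorly, and resource consumption grows sharply. If a single model is trained on the mixed samples, the differences between scenarios are ignored and prediction accuracy drops.</p>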
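<p>To make the contrast concrete, the sketch below shows one possible way to lay out a sample in each setting. The class names, field names, and scenario ids are hypothetical and chosen only for illustration; they are not taken from the FunRec code.</p>

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class MultiTaskSample:
    """One sample, several targets (same scenario/distribution)."""
    features: Dict[str, float]
    ctr_label: int  # click label
    cvr_label: int  # conversion label


@dataclass
class MultiScenarioSample:
    """One target, but each sample records the scenario it came from."""
    features: Dict[str, float]
    scenario_id: int  # hypothetical ids, e.g. 0 = homepage feed, 1 = detail page
    ctr_label: int


# Multi-task: a single sample carries both a CTR and a CVR label.
multi_task = MultiTaskSample({"price": 9.9}, ctr_label=1, cvr_label=0)

# Multi-scenario: samples from different scenarios share the same CTR target.
multi_scenario = [
    MultiScenarioSample({"price": 9.9}, scenario_id=0, ctr_label=1),
    MultiScenarioSample({"price": 9.9}, scenario_id=1, ctr_label=0),
]
```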
<figure class="align-default" id="id3">
<span id="multi-tower-diff"></span><a class="reference internal image-reference" href="../../_images/star_1.png"><img alt="../../_images/star_1.png" src="../../_images/star_1.png" style="width: 500px;" /></a>
<figcaption>
<p><span class="caption-number">Figure 3.5.1 </span><span class="caption-text">Differences between multi-objective and multi-scenario modeling (image from the Alimama blog)</span><a class="headerlink" href="#id3" title="Permalink to this image">¶</a></p>
</figcaption>
</figure>
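<p>This subsection describes how multi-tower architectures, while exploiting the commonality across scenarios, explicitly use scenario-specific signals to capture the characteristics of each scenario.</p>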
<section id="hmoe">
<h2><span class="section-number">3.5.1.1. </span>HMoE<a class="headerlink" href="#hmoe" title="Permalink to this heading">¶</a></h2>
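<p>The multi-task modeling section introduced MMoE (Multi-gate Mixture-of-Experts): at the bottom, multiple expert networks provide the features shared by all tasks, and at the top, a gating mechanism for each task fuses the expert features so that each task learns differently. For multi-scenario modeling, HMoE
<span id="id2">(<a class="reference internal" href="../../chapter_references/references.html#id95" title="Li, P., Li, R., Da, Q., Zeng, A.-X., &amp; Zhang, L. (2020). Improving multi-scenario learning to rank in e-commerce by exploiting task relationships in the label space. Proceedings of the 29th ACM International Conference on Information &amp; Knowledge Management (pp. 2605–2612).">Li <em>et al.</em>, 2020</a>)</span>
borrows the idea of MMoE. The bottom likewise uses multiple expert networks to extract features from multiple scenarios as shared features, except that the towers at the top no longer output multiple tasks but multiple scenarios. The HMoE model structure is shown below:</p>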
<figure class="align-default" id="id4">
<span id="hmoe-model-structure"></span><a class="reference internal image-reference" href="../../_images/hmoe.png"><img alt="../../_images/hmoe.png" src="../../_images/hmoe.png" style="width: 300px;" /></a>
<figcaption>
<p><span class="caption-number">Figure 3.5.2 </span><span class="caption-text">HMoE model structure</span><a class="headerlink" href="#id4" title="Permalink to this image">¶</a></p>
</figcaption>
</figure>
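<p>At the bottom, the model uses multiple experts to extract features from multiple scenarios, fuses the expert outputs through a set of gating networks, and feeds the result to the scenario towers above.</p>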
<div class="math notranslate nohighlight" id="equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-0">
<span class="eqno">(3.5.1)<a class="headerlink" href="#equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-0" title="Permalink to this equation">¶</a></span>\[M(x) = \sum_{i=1}^{K} G_i(x) E_i(x)\]</div>
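<p>Here <span class="math notranslate nohighlight">\(E_i(x)\)</span> is the output of the <span class="math notranslate nohighlight">\(i\)</span>-th expert, <span class="math notranslate nohighlight">\(G_i(x)\)</span> is its gate weight, and <span class="math notranslate nohighlight">\(K\)</span> is the number of experts. In the original paper, the towers of all scenarios use expert features fused by the same set of gates. This can be viewed as Shared-Bottom-style feature sharing from multi-task modeling, except that the fused output of several FCNs replaces the output of a single FCN. Experience with MMoE suggests that when tasks are weakly correlated, such hard sharing of bottom-level features can cause negative transfer. This approach is therefore not necessarily optimal for multi-scenario modeling; one can also give each scenario expert features fused by its own gate. For example, the input features of the <span class="math notranslate nohighlight">\(t\)</span>-th scenario can be written as <span class="math notranslate nohighlight">\(M_t(x) = \sum_{i=1}^{K} G_i^t(x) E_i(x)\)</span>. Which variant works better should be determined by experiments on your own scenarios.</p>
<p>After the bottom-level multi-scenario features are obtained, the final prediction for a single scenario is not simply the score of that scenario's Tower. Instead, the scores output by multiple scenarios are fused into the score for that one scenario. The model score for the <span class="math notranslate nohighlight">\(t\)</span>-th scenario is:</p>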
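<p>The gated mixture of experts can be sketched in plain Python as follows. This is a minimal illustration with toy numbers, not the FunRec implementation: the helper names <code>softmax</code> and <code>mix_experts</code> and all values are assumptions, and in a real model both the expert outputs and the gate logits would be produced by networks of the input.</p>

```python
import math


def softmax(logits):
    """Normalize gate logits into weights that sum to 1."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mix_experts(gate_weights, expert_outputs):
    """Compute sum_i G_i(x) * E_i(x) for one sample, dimension by dimension."""
    dims = len(expert_outputs[0])
    return [
        sum(g * e[d] for g, e in zip(gate_weights, expert_outputs))
        for d in range(dims)
    ]


# Toy values: K = 2 experts, each emitting a 3-dimensional feature vector.
expert_outputs = [[1.0, 0.0, 2.0], [3.0, 4.0, 0.0]]

# Shared gate: every scenario tower receives the same mixture M(x).
shared_gate = softmax([0.0, 0.0])  # equal logits give weights [0.5, 0.5]
shared_input = mix_experts(shared_gate, expert_outputs)

# Per-scenario gates: scenario t mixes the experts with its own weights G^t(x).
per_scenario_logits = [[2.0, 0.0], [0.0, 2.0]]
per_scenario_inputs = [
    mix_experts(softmax(logits), expert_outputs)
    for logits in per_scenario_logits
]
```

<p>With the shared gate both towers see the same vector, whereas with per-scenario gates the first scenario leans on the first expert and the second scenario on the second expert.</p>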
<div class="math notranslate nohighlight" id="equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-1">
<span class="eqno">(3.5.2)<a class="headerlink" href="#equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-1" title="Permalink to this equation">¶</a></span>\[out_t = \sum_{i=1}^{T} W_i(x) S_i(x)\]</div>
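<p>Here <span class="math notranslate nohighlight">\(S_i(x)\)</span> is the score output by the Tower of scenario <span class="math notranslate nohighlight">\(i\)</span>, and <span class="math notranslate nohighlight">\(W_i(x)\)</span> is the fusion weight of scenario <span class="math notranslate nohighlight">\(i\)</span>. The original paper does not state explicitly whether the fusion weights <span class="math notranslate nohighlight">\(W\)</span> are shared across scenarios. Following the idea of MMoE, one can learn a separate set of fusion weights for each scenario, so that the prediction for the <span class="math notranslate nohighlight">\(t\)</span>-th scenario becomes: <span class="math notranslate nohighlight">\(out_t = \sum_{i=1}^{T} W_i^t(x) S_i(x)\)</span></p>
<p>Because the final score of a single scenario is a fusion of the scores from multiple scenarios, for a sample from scenario <span class="math notranslate nohighlight">\(t\)</span> HMoE must compute not only its score under scenario <span class="math notranslate nohighlight">\(t\)</span> but also its scores under the other scenarios. When the final score for scenario <span class="math notranslate nohighlight">\(t\)</span> is computed, the scores of the other scenarios are also informative for scenario <span class="math notranslate nohighlight">\(t\)</span>.</p>
<p>Although the forward pass can produce scores for every scenario from one sample, a sample from scenario <span class="math notranslate nohighlight">\(t\)</span> should only affect the parameters of its own scenario (mainly the scenario tower parameters). Otherwise, samples from scenario <span class="math notranslate nohighlight">\(a\)</span> would directly affect the parameters of scenario <span class="math notranslate nohighlight">\(b\)</span>, which easily weakens the model's awareness of scenarios and degrades the overall multi-scenario model. Therefore, when the fused score is computed, the gradients flowing back through the other scenarios' scores must be blocked. The final score for scenario <span class="math notranslate nohighlight">\(t\)</span> is:</p>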
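<p>The two fusion variants (shared versus per-scenario weights) can be compared with a small sketch. The function name <code>fuse_scores</code> and all numbers are illustrative assumptions; in a trained model the tower scores and fusion weights would both be computed from the input.</p>

```python
def fuse_scores(fusion_weights, tower_scores):
    """Compute sum_i W_i(x) * S_i(x) for one sample."""
    return sum(w * s for w, s in zip(fusion_weights, tower_scores))


# Toy tower scores S_1(x), S_2(x), S_3(x) for T = 3 scenarios.
tower_scores = [0.2, 0.6, 0.4]

# Shared weights: the same W is used whichever scenario the sample belongs to.
shared_weights = [0.5, 0.25, 0.25]
out_shared = fuse_scores(shared_weights, tower_scores)

# Per-scenario weights: row t holds W^t and is used for samples of scenario t.
per_scenario_weights = [
    [0.8, 0.1, 0.1],
    [0.1, 0.8, 0.1],
    [0.1, 0.1, 0.8],
]
out_per_scenario = [
    fuse_scores(w_t, tower_scores) for w_t in per_scenario_weights
]
```

<p>In the per-scenario variant each scenario can put most of its weight on its own tower while still borrowing a little from the others.</p>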
<div class="math notranslate nohighlight" id="equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-2">
<span class="eqno">(3.5.3)<a class="headerlink" href="#equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-2" title="Permalink to this equation">¶</a></span>\[out_t(x) = W_t(x) S_t(x) + \sum_{j=1, j \neq t}^{T} W_j(x) \underbrace{S_j(x)}_{\text{stop gradient}}\]</div>
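<p>With fusion weights that are not shared across scenarios, the scoring formula is: <span class="math notranslate nohighlight">\(out_t(x) = W_t^t(x) S_t(x) + \sum_{j=1, j \neq t}^{T} W_j^t(x) \underbrace{S_j(x)}_{\text{stop gradient}}\)</span></p>
<p>The core HMoE code is shown below; it includes the parts that control whether the gates and the score-fusion weights are shared.</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">import</span> <span class="nn">tensorflow</span> <span class="k">as</span> <span class="nn">tf</span>

<span class="c1"># DNNs and concat_group_embedding are assumed to be helpers defined elsewhere in the FunRec code</span>
<span class="c1"># Build the DNN input</span>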
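<p>As a worked illustration of what the stop gradient does, assume that it treats <span class="math notranslate nohighlight">\(S_j(x)\)</span> as a constant in the backward pass while leaving the forward value unchanged. Under that assumption, the partial derivatives of the shared-weight score <span class="math notranslate nohighlight">\(out_t(x)\)</span> are:</p>
<div class="math notranslate nohighlight">
\[\frac{\partial\, out_t}{\partial S_t} = W_t(x), \qquad \frac{\partial\, out_t}{\partial S_j} = 0 \ \ (j \neq t), \qquad \frac{\partial\, out_t}{\partial W_j} = S_j(x) \ \ (j = 1, \dots, T)\]</div>
<p>A sample from scenario <span class="math notranslate nohighlight">\(t\)</span> therefore sends no gradient through the scores of the other towers, while the fusion weights are still trained on the scores of all towers.</p>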
<span class="n">dnn_inputs</span> <span class="o">=</span> <span class="n">concat_group_embedding</span><span class="p">(</span><span class="n">group_embedding_feature_dict</span><span class="p">,</span> <span class="s1">&#39;dnn&#39;</span><span class="p">)</span>

<span class="c1"># Create multiple experts</span>
<span class="n">expert_output_list</span> <span class="o">=</span> <span class="p">[</span>
    <span class="n">DNNs</span><span class="p">(</span><span class="n">shared_expert_dnn_units</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;expert_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)(</span><span class="n">dnn_inputs</span><span class="p">)</span>
    <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">shared_expert_nums</span><span class="p">)</span>
<span class="p">]</span>
<span class="n">expert_concat</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">stack</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">))(</span><span class="n">expert_output_list</span><span class="p">)</span>  <span class="c1"># (None, expert_num, dims)</span>

<span class="c1"># Per-domain independent gate that fuses the expert outputs</span>
<span class="n">domain_tower_input_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_domains</span><span class="p">):</span>
    <span class="n">gate_output</span> <span class="o">=</span> <span class="n">DNNs</span><span class="p">(</span><span class="n">gate_dnn_units</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;gate_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)(</span><span class="n">dnn_inputs</span><span class="p">)</span>
    <span class="n">gate_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">shared_expert_nums</span><span class="p">,</span> <span class="n">use_bias</span><span class="o">=</span><span class="kc">False</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;softmax&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;gate_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">_softmax&quot;</span><span class="p">)(</span><span class="n">gate_output</span><span class="p">)</span>
    <span class="n">gate_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">))(</span><span class="n">gate_output</span><span class="p">)</span>  <span class="c1"># (None, expert_num, 1)</span>
    <span class="n">gate_expert_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">])([</span><span class="n">gate_output</span><span class="p">,</span> <span class="n">expert_concat</span><span class="p">])</span>
    <span class="n">gate_expert_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">reduce_sum</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">keepdims</span><span class="o">=</span><span class="kc">False</span><span class="p">))(</span><span class="n">gate_expert_output</span><span class="p">)</span>
    <span class="n">domain_tower_input_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">gate_expert_output</span><span class="p">)</span>

<span class="c1"># Define the tower for each domain</span>
<span class="n">domain_tower_output_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_domains</span><span class="p">):</span>
    <span class="n">domain_dnn_input</span> <span class="o">=</span> <span class="n">domain_tower_input_list</span><span class="p">[</span><span class="n">i</span><span class="p">]</span>
    <span class="n">task_output</span> <span class="o">=</span> <span class="n">DNNs</span><span class="p">(</span><span class="n">domain_tower_units</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;tower_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)(</span><span class="n">domain_dnn_input</span><span class="p">)</span>
    <span class="n">domain_tower_output_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">task_output</span><span class="p">)</span>

<span class="c1"># Cross-domain fusion weights (parameterized, with optional preprocessing)</span>
<span class="k">if</span> <span class="n">domain_weight_units</span> <span class="ow">is</span> <span class="ow">not</span> <span class="kc">None</span><span class="p">:</span>
    <span class="n">domain_weight_input</span> <span class="o">=</span> <span class="n">DNNs</span><span class="p">(</span><span class="n">domain_weight_units</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;domain_weight_dnn&quot;</span><span class="p">)(</span><span class="n">dnn_inputs</span><span class="p">)</span>
<span class="k">else</span><span class="p">:</span>
    <span class="n">domain_weight_input</span> <span class="o">=</span> <span class="n">dnn_inputs</span>
<span class="n">domain_weight</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dense</span><span class="p">(</span><span class="n">num_domains</span><span class="p">,</span> <span class="n">activation</span><span class="o">=</span><span class="s1">&#39;softmax&#39;</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;domain_weight&quot;</span><span class="p">)(</span><span class="n">domain_weight_input</span><span class="p">)</span>  <span class="c1"># (None, num_domains)</span>

<span class="c1"># Fuse domain information (own tower + other towers under stop_gradient)</span>
<span class="n">mixed_output_list</span> <span class="o">=</span> <span class="p">[]</span>
<span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_domains</span><span class="p">):</span>
    <span class="n">wi</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">w</span><span class="p">,</span> <span class="n">i</span><span class="o">=</span><span class="n">i</span><span class="p">:</span> <span class="n">w</span><span class="p">[:,</span> <span class="n">i</span><span class="p">:</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">])(</span><span class="n">domain_weight</span><span class="p">)</span>
    <span class="n">weighted_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">x</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span> <span class="o">*</span> <span class="n">x</span><span class="p">[</span><span class="mi">1</span><span class="p">])([</span><span class="n">domain_tower_output_list</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">wi</span><span class="p">])</span>
    <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_domains</span><span class="p">):</span>
        <span class="k">if</span> <span class="n">i</span> <span class="o">==</span> <span class="n">j</span><span class="p">:</span>
            <span class="k">continue</span>
        <span class="n">grad_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">x</span><span class="p">:</span> <span class="n">tf</span><span class="o">.</span><span class="n">stop_gradient</span><span class="p">(</span><span class="n">x</span><span class="p">))(</span><span class="n">domain_tower_output_list</span><span class="p">[</span><span class="n">j</span><span class="p">])</span>
        <span class="n">wj</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Lambda</span><span class="p">(</span><span class="k">lambda</span> <span class="n">w</span><span class="p">,</span> <span class="n">j</span><span class="o">=</span><span class="n">j</span><span class="p">:</span> <span class="n">w</span><span class="p">[:,</span> <span class="n">j</span><span class="p">:</span><span class="n">j</span><span class="o">+</span><span class="mi">1</span><span class="p">])(</span><span class="n">domain_weight</span><span class="p">)</span>
        <span class="n">weighted_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Add</span><span class="p">()([</span>
            <span class="n">weighted_output</span><span class="p">,</span>
            <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Multiply</span><span class="p">()([</span><span class="n">wj</span><span class="p">,</span> <span class="n">grad_output</span><span class="p">])</span>
        <span class="p">])</span>
    <span class="n">mixed_output_list</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">weighted_output</span><span class="p">)</span>

<span class="c1"># Concatenate the data of all domains along the batch axis and output logits</span>
<span class="n">final_domain_output</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Concatenate</span><span class="p">(</span><span class="n">axis</span><span class="o">=</span><span class="mi">0</span><span class="p">)(</span><span class="n">mixed_output_list</span><span class="p">)</span>
<span class="n">dnn_logits</span> <span class="o">=</span> <span class="n">PredictLayer</span><span class="p">(</span><span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;dnn_logits&quot;</span><span class="p">)(</span><span class="n">final_domain_output</span><span class="p">)</span>
</pre></div>
</div>
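<p>To make the fusion step above concrete, the following is a minimal NumPy sketch of the forward computation only. The helper names <code>softmax</code> and <code>mix_domain_outputs</code> and the toy shapes are assumptions made for illustration. NumPy has no gradients, so the effect of <code>tf.stop_gradient</code> (the other towers are treated as constants during back-propagation) is only noted in a comment.</p>

```python
import numpy as np

def softmax(z, axis=-1):
    # Numerically stable softmax over the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mix_domain_outputs(tower_outputs, domain_weight):
    """Forward pass of the fusion step (hypothetical helper).

    tower_outputs: list of num_domains arrays, each [B, D]
    domain_weight: array [B, num_domains], rows sum to 1
    returns: list of num_domains arrays, each [B, D]
    """
    num_domains = len(tower_outputs)
    mixed = []
    for i in range(num_domains):
        # Own tower: in the TensorFlow code gradients flow through this term.
        out = domain_weight[:, i:i + 1] * tower_outputs[i]
        for j in range(num_domains):
            if j == i:
                continue
            # Other towers: wrapped in stop_gradient in the TensorFlow code,
            # i.e. they contribute to the value but are constants in the backward pass.
            out = out + domain_weight[:, j:j + 1] * tower_outputs[j]
        mixed.append(out)
    return mixed

rng = np.random.default_rng(0)
towers = [rng.normal(size=(4, 3)) for _ in range(2)]   # 2 domains, B=4, D=3
weights = softmax(rng.normal(size=(4, 2)))              # [B, num_domains]
mixed = mix_domain_outputs(towers, weights)
```

<p>In the forward pass every entry of <code>mixed</code> has the same value (a softmax-weighted sum of all towers); the entries differ only in which tower receives gradients during training.</p>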
<p><strong>Hands-on Code</strong></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="kn">from</span><span class="w"> </span><span class="nn">funrec</span><span class="w"> </span><span class="kn">import</span> <span class="n">run_experiment</span>

<span class="n">run_experiment</span><span class="p">(</span><span class="s1">&#39;hmoe&#39;</span><span class="p">)</span>
</pre></div>
</div>
<div class="output highlight-default notranslate"><div class="highlight"><pre><span></span><span class="o">+--------+--------+------------+</span>
<span class="o">|</span>    <span class="n">auc</span> <span class="o">|</span>   <span class="n">gauc</span> <span class="o">|</span>   <span class="n">val_user</span> <span class="o">|</span>
<span class="o">+========+========+============+</span>
<span class="o">|</span> <span class="mf">0.5972</span> <span class="o">|</span> <span class="mf">0.5471</span> <span class="o">|</span>        <span class="mi">217</span> <span class="o">|</span>
<span class="o">+--------+--------+------------+</span>
</pre></div>
</div>
</section>
<section id="star">
<h2><span class="section-number">3.5.1.2. </span>STAR<a class="headerlink" href="#star" title="Permalink to this heading">¶</a></h2>
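<p>STAR (Star Topology Adaptive Recommender) <cite>sheng2021one</cite>
adopts a star topology: scenario-specific parameters and scenario-shared parameters jointly model the differences and the commonalities across scenarios, and the two sets of parameters are aggregated to form the model of each scenario. The STAR structure is shown in the figure below.</p>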
<figure class="align-default" id="id5">
<span id="star-model-structure"></span><a class="reference internal image-reference" href="../../_images/star_2.png"><img alt="../../_images/star_2.png" src="../../_images/star_2.png" style="width: 500px;" /></a>
<figcaption>
<p><span class="caption-number">Fig. 3.5.3 </span><span class="caption-text">STAR model structure</span><a class="headerlink" href="#id5" title="Permalink to this image">¶</a></p>
</figcaption>
</figure>
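<p>Compared with single-scenario models, STAR offers three ideas for multi-scenario modeling that are worth studying: the star topology fully-connected network (STAR
Topology Fully-Connected Network), Partitioned Normalization,
and an auxiliary network. They are introduced in turn below.</p>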
<p><strong>STAR Topology Fully-Connected Network</strong></p>
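<p>The core idea of the star topology fully-connected structure is that every fully-connected network (FCN) has a scenario-shared part and a scenario-specific part. The final weights of each scenario are obtained by fusing the shared and the specific weights through an element-wise
product, while the biases are added.</p>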
<figure class="align-default" id="id6">
<span id="star-fcn-structure"></span><a class="reference internal image-reference" href="../../_images/star_fcn.png"><img alt="../../_images/star_fcn.png" src="../../_images/star_fcn.png" style="width: 400px;" /></a>
<figcaption>
<p><span class="caption-number">Fig. 3.5.4 </span><span class="caption-text">STAR FCN structure</span><a class="headerlink" href="#id6" title="Permalink to this image">¶</a></p>
</figcaption>
</figure>
<p>Specifically, the final FCN parameters <span class="math notranslate nohighlight">\(W_p^{\star},b_p^{\star}\)</span> of the <span class="math notranslate nohighlight">\(p\)</span>-th scenario are given by:</p>
<div class="math notranslate nohighlight" id="equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-3">
<span class="eqno">(3.5.4)<a class="headerlink" href="#equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-3" title="Permalink to this equation">¶</a></span>\[\begin{split}W_p^{\star} = W_p \otimes W \\
b_p^{\star} = b_p + b\end{split}\]</div>
<p>Here <span class="math notranslate nohighlight">\(W_p,W\)</span> denote the weights specific to the <span class="math notranslate nohighlight">\(p\)</span>-th scenario and the weights shared by all scenarios, respectively; the same holds for the biases <span class="math notranslate nohighlight">\(b_p,b\)</span>.</p>
<p>If <span class="math notranslate nohighlight">\(in_p\)</span> denotes the FCN input of the <span class="math notranslate nohighlight">\(p\)</span>-th scenario, the output <span class="math notranslate nohighlight">\(out_p\)</span> of this star FCN layer is:</p>
<div class="math notranslate nohighlight" id="equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-4">
<span class="eqno">(3.5.5)<a class="headerlink" href="#equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-4" title="Permalink to this equation">¶</a></span>\[out_p = \phi((W_p^\star)^\top in_p + b_p^\star),\]</div>
<p>where <span class="math notranslate nohighlight">\(\phi\)</span> is the activation function.</p>
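<p>As a small worked instance of Eqs. (3.5.4) and (3.5.5), the NumPy sketch below computes one star FCN layer for a batch of samples from a single scenario. The function name <code>star_fcn_layer</code>, the row-vector batch layout (so that <span class="math notranslate nohighlight">\((W_p^\star)^\top in_p\)</span> becomes <code>in_p @ W_star</code>), and the choice of ReLU for <span class="math notranslate nohighlight">\(\phi\)</span> are assumptions made for illustration.</p>

```python
import numpy as np

def star_fcn_layer(in_p, W_shared, b_shared, W_p, b_p):
    """One star-topology FCN layer for scenario p (hypothetical helper).

    in_p:     [B, Din]  inputs from scenario p, one sample per row
    W_shared: [Din, Dout], b_shared: [Dout]  parameters shared by all scenarios
    W_p:      [Din, Dout], b_p:      [Dout]  parameters specific to scenario p
    """
    W_star = W_p * W_shared          # Eq. (3.5.4): element-wise product of weights
    b_star = b_p + b_shared          # Eq. (3.5.4): biases are added
    # Eq. (3.5.5) with phi = ReLU (an assumption for this sketch).
    return np.maximum(in_p @ W_star + b_star, 0.0)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 4))          # B=5, Din=4
W = rng.normal(size=(4, 3))          # shared weights, Dout=3
b = rng.normal(size=3)               # shared bias
W_p = np.ones((4, 3))                # scenario weights set to all ones
b_p = np.zeros(3)                    # scenario bias set to zero
out = star_fcn_layer(x, W, b, W_p, b_p)
```

<p>With <code>W_p</code> set to ones and <code>b_p</code> set to zeros, the scenario-specific layer reduces to the shared layer, which is one natural starting point for the scenario parameters under the multiplicative form.</p>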
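<p>A simplified implementation of the STAR Topology Fully-Connected Network is given below. Note that this version fuses the shared and the domain-specific weights by addition rather than by the element-wise product of Eq. (3.5.4); as the code comment points out, multiplication can be used instead.</p>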
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">StarTopologyFCN</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Star topology FCN (simplified)</span>
<span class="sd">    - Core idea: central shared parameters + lightweight per-domain adapters (incremental parameters).</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span>
                 <span class="n">num_domain</span><span class="p">,</span>
                 <span class="n">hidden_units</span><span class="p">,</span>
                 <span class="n">activation</span><span class="o">=</span><span class="s2">&quot;relu&quot;</span><span class="p">,</span>
                 <span class="n">dropout</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span>
                 <span class="n">l2_reg</span><span class="o">=</span><span class="mf">0.</span><span class="p">,</span>
                 <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">StarTopologyFCN</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_domain</span> <span class="o">=</span> <span class="n">num_domain</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span> <span class="o">=</span> <span class="nb">list</span><span class="p">(</span><span class="n">hidden_units</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">activation</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">activations</span><span class="o">.</span><span class="n">get</span><span class="p">(</span><span class="n">activation</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Dropout</span><span class="p">(</span><span class="n">dropout</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">l2_reg</span> <span class="o">=</span> <span class="n">l2_reg</span>

    <span class="k">def</span><span class="w"> </span><span class="nf">build</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">input_shape</span><span class="p">):</span>
        <span class="c1"># inputs is a pair (x, domain_index); only the shape of the features x is used here</span>
        <span class="n">input_shape</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="n">input_dim</span> <span class="o">=</span> <span class="n">input_shape</span><span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>
        <span class="n">layer_dims</span> <span class="o">=</span> <span class="p">[</span><span class="n">input_dim</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span>

        <span class="c1"># Shared parameters (center of the star): shared kernel / shared bias of each layer</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">shared_kernels</span> <span class="o">=</span> <span class="p">[</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span>
                <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;shared_kernel_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
                <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">layer_dims</span><span class="p">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">]),</span>
                <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;glorot_uniform&quot;</span><span class="p">,</span>
                <span class="n">regularizer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">l2_reg</span><span class="p">),</span>
                <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
            <span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span><span class="p">))</span>
        <span class="p">]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">shared_biases</span> <span class="o">=</span> <span class="p">[</span>
            <span class="bp">self</span><span class="o">.</span><span class="n">add_weight</span><span class="p">(</span>
                <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;shared_bias_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">,</span>
                <span class="n">shape</span><span class="o">=</span><span class="p">(</span><span class="n">layer_dims</span><span class="p">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">],),</span>
                <span class="n">initializer</span><span class="o">=</span><span class="s2">&quot;zeros&quot;</span><span class="p">,</span>
                <span class="n">trainable</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span>
            <span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span><span class="p">))</span>
        <span class="p">]</span>

        <span class="c1"># Domain adapters (spokes of the star): an Embedding yields per-domain incremental parameters (kernel / bias)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">domain_kernel_embs</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">num_domain</span><span class="p">,</span>
                <span class="n">layer_dims</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">*</span> <span class="n">layer_dims</span><span class="p">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
                <span class="n">embeddings_initializer</span><span class="o">=</span><span class="s2">&quot;glorot_uniform&quot;</span><span class="p">,</span>
                <span class="n">embeddings_regularizer</span><span class="o">=</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">regularizers</span><span class="o">.</span><span class="n">l2</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">l2_reg</span><span class="p">),</span>
            <span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span><span class="p">))</span>
        <span class="p">]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">domain_bias_embs</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Embedding</span><span class="p">(</span>
                <span class="bp">self</span><span class="o">.</span><span class="n">num_domain</span><span class="p">,</span>
                <span class="n">layer_dims</span><span class="p">[</span><span class="n">i</span><span class="o">+</span><span class="mi">1</span><span class="p">],</span>
                <span class="n">embeddings_initializer</span><span class="o">=</span><span class="s2">&quot;zeros&quot;</span><span class="p">,</span>
            <span class="p">)</span> <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span><span class="p">))</span>
        <span class="p">]</span>

    <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="c1"># x: [B, Din]; domain_index: [B]</span>
        <span class="n">x</span><span class="p">,</span> <span class="n">domain_index</span> <span class="o">=</span> <span class="n">inputs</span>
        <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="nb">len</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">hidden_units</span><span class="p">)):</span>
            <span class="c1"># Look up the domain incremental parameters and reshape them into a matrix</span>
            <span class="n">delta_w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">domain_kernel_embs</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">domain_index</span><span class="p">)</span>                      <span class="c1"># [B, Din*Dout]</span>
            <span class="n">delta_w</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">delta_w</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_kernels</span><span class="p">[</span><span class="n">i</span><span class="p">]</span><span class="o">.</span><span class="n">shape</span><span class="o">.</span><span class="n">as_list</span><span class="p">())</span>  <span class="c1"># [B, Din, Dout]</span>
            <span class="c1"># Star fusion: shared kernel + domain increment (multiplication is also possible, depending on experiments)</span>
            <span class="n">w</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_kernels</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">delta_w</span>                                    <span class="c1"># [B, Din, Dout]</span>
            <span class="n">delta_b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">domain_bias_embs</span><span class="p">[</span><span class="n">i</span><span class="p">](</span><span class="n">domain_index</span><span class="p">)</span>                        <span class="c1"># [B, Dout]</span>
            <span class="n">b</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">shared_biases</span><span class="p">[</span><span class="n">i</span><span class="p">]</span> <span class="o">+</span> <span class="n">delta_b</span>                                     <span class="c1"># [B, Dout]</span>

            <span class="c1"># Linear layer (each sample uses the parameters of its own domain)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>                                           <span class="c1"># [B, 1, Din]</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">matmul</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">w</span><span class="p">)</span> <span class="o">+</span> <span class="n">tf</span><span class="o">.</span><span class="n">expand_dims</span><span class="p">(</span><span class="n">b</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>                              <span class="c1"># [B, 1, Dout]</span>
            <span class="n">x</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">squeeze</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>                                               <span class="c1"># [B, Dout]</span>

            <span class="c1"># Activation + Dropout (kept lightweight)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">activation</span><span class="p">(</span><span class="n">x</span><span class="p">)</span>
            <span class="n">x</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">dropout</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="n">training</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">x</span>
</pre></div>
</div>
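<p>The per-sample parameter lookup used in <code>call</code> above can be checked in isolation. The following is a minimal NumPy sketch of a single linear layer using the same additive fusion; the function name <code>per_sample_linear</code>, the array names, and the use of <code>np.einsum</code> in place of the expand/matmul/squeeze sequence are assumptions made for illustration.</p>

```python
import numpy as np

def per_sample_linear(x, domain_index, shared_w, shared_b, domain_w, domain_b):
    """Linear layer where each sample uses the parameters of its own domain.

    x:            [B, Din]
    domain_index: [B] integer domain ids
    shared_w:     [Din, Dout], shared_b: [Dout]
    domain_w:     [num_domain, Din, Dout], domain_b: [num_domain, Dout]
    """
    # Gather one increment per sample and add the shared parameters (broadcast over B).
    w = shared_w[None, :, :] + domain_w[domain_index]   # [B, Din, Dout]
    b = shared_b[None, :] + domain_b[domain_index]      # [B, Dout]
    # Batched vector-matrix product: sample n is multiplied by its own w[n].
    return np.einsum("bi,bio->bo", x, w) + b            # [B, Dout]

rng = np.random.default_rng(0)
B, Din, Dout, num_domain = 6, 4, 3, 2
x = rng.normal(size=(B, Din))
d = np.array([0, 1, 0, 1, 1, 0])
shared_w = rng.normal(size=(Din, Dout))
shared_b = rng.normal(size=Dout)
domain_w = rng.normal(size=(num_domain, Din, Dout))
domain_b = rng.normal(size=(num_domain, Dout))
out = per_sample_linear(x, d, shared_w, shared_b, domain_w, domain_b)
```

<p>Materializing a <code>[B, Din, Dout]</code> weight tensor keeps the code short but costs memory proportional to the batch size; grouping samples by domain before the matrix product is a possible alternative when layers are wide.</p>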
<p><strong>Partitioned Normalization</strong></p>
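<p>Batch Normalization (BN) is commonly added to neural networks to speed up convergence.
In multi-scenario modeling, however, samples are independent and identically distributed only within the same scenario, so statistics computed over samples mixed from several scenarios ignore the distributional differences specific to each scenario. Each scenario should therefore keep its own statistics, which is the main motivation for Partitioned
Normalization (PN).</p>
<p>Before introducing PN, let us briefly review how classic BN works:</p>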
<div class="math notranslate nohighlight" id="equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-5">
<span class="eqno">(3.5.6)<a class="headerlink" href="#equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-5" title="Permalink to this equation">¶</a></span>\[\mathbf{z'} = \gamma \frac{\mathbf{z} - \mathbf{E}}{\sqrt{\mathrm{Var} + \epsilon}} + \beta\]</div>
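<p>where <span class="math notranslate nohighlight">\(\mathbf{E},\mathrm{Var}\)</span> are the moving mean and variance, and <span class="math notranslate nohighlight">\(\gamma,\beta\)</span> are learnable parameters that scale and shift the data.</p>
<p>Compared with BN, PN not only splits the learnable scale and shift parameters into a scenario-shared part and a scenario-specific part, but also computes the moving mean and variance separately on the samples of each scenario:</p>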
<div class="math notranslate nohighlight" id="equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-6">
<span class="eqno">(3.5.7)<a class="headerlink" href="#equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-6" title="Permalink to this equation">¶</a></span>\[\mathbf{z'} = (\gamma * \gamma_p) \frac{\mathbf{z} - \mathbf{E_p}}{\sqrt{\mathrm{Var_p} + \epsilon}} + (\beta + \beta_p)\]</div>
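<p>where <span class="math notranslate nohighlight">\(\gamma,\beta\)</span> and <span class="math notranslate nohighlight">\(\gamma_p,\beta_p\)</span> are the scenario-shared and the scenario-specific parameters, respectively, and <span class="math notranslate nohighlight">\(\mathbf{E_p},\mathrm{Var_p}\)</span> are the moving mean and variance computed on the samples of scenario <span class="math notranslate nohighlight">\(p\)</span>. Because PN statistics are computed from the samples in a batch, a somewhat larger batch
size during training helps to obtain more stable per-scenario means and variances.</p>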
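<p>Eq. (3.5.7) can be illustrated with a few lines of NumPy. The sketch below is a training-mode approximation under stated assumptions: it normalizes with the statistics of the current batch for each scenario instead of maintaining moving averages, and the function name <code>partitioned_norm</code> and the parameter layout (one row of <code>gamma_p</code> / <code>beta_p</code> per scenario) are hypothetical.</p>

```python
import numpy as np

def partitioned_norm(z, domain_index, gamma, beta, gamma_p, beta_p, eps=1e-5):
    """Partitioned Normalization using per-scenario batch statistics.

    z:            [B, D] float activations
    domain_index: [B] integer scenario ids
    gamma, beta:  [D] scenario-shared scale and shift
    gamma_p, beta_p: [num_domain, D] scenario-specific scale and shift
    """
    out = np.empty_like(z)
    for p in np.unique(domain_index):
        mask = domain_index == p
        mean = z[mask].mean(axis=0)              # E_p, from this scenario's samples only
        var = z[mask].var(axis=0)                # Var_p, from this scenario's samples only
        z_hat = (z[mask] - mean) / np.sqrt(var + eps)
        # Eq. (3.5.7): scales are multiplied, shifts are added.
        out[mask] = (gamma * gamma_p[p]) * z_hat + (beta + beta_p[p])
    return out

rng = np.random.default_rng(0)
z = rng.normal(size=(6, 3))
z[3:] += 10.0                                    # scenario 1 has a shifted distribution
domain_index = np.array([0, 0, 0, 1, 1, 1])
gamma, beta = np.ones(3), np.zeros(3)
gamma_p, beta_p = np.ones((2, 3)), np.zeros((2, 3))
out = partitioned_norm(z, domain_index, gamma, beta, gamma_p, beta_p)
```

<p>With identity scale and zero shift, the output of each scenario is centered at zero even though the raw activations of the two scenarios differ by a large offset, which is the effect that mixed-batch BN statistics would not achieve.</p>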
<p>A simplified implementation of Partitioned Normalization is as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="k">class</span><span class="w"> </span><span class="nc">PartitionedNormalization</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">Layer</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
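<span class="sd">    Partitioned normalization by domain (simplified)</span>
<span class="sd">    - Core idea: split the batch by domain, apply a separate BN to each sub-batch,</span>
<span class="sd">      then write the results back to their original positions to avoid mixing statistics across domains.</span>
<span class="sd">    - Implementation: select by mask -&gt; per-domain BN -&gt; scatter back.</span>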
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="k">def</span><span class="w"> </span><span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">num_domain</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">PartitionedNormalization</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="n">name</span><span class="p">,</span> <span class="o">**</span><span class="n">kwargs</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_domain</span> <span class="o">=</span> <span class="n">num_domain</span>
        <span class="c1"># One BN per domain (with scale and center), which makes the &quot;partitioned normalization&quot; idea explicit</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">bn_list</span> <span class="o">=</span> <span class="p">[</span>
            <span class="n">tf</span><span class="o">.</span><span class="n">keras</span><span class="o">.</span><span class="n">layers</span><span class="o">.</span><span class="n">BatchNormalization</span><span class="p">(</span><span class="n">center</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">scale</span><span class="o">=</span><span class="kc">True</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="sa">f</span><span class="s2">&quot;bn_</span><span class="si">{</span><span class="n">i</span><span class="si">}</span><span class="s2">&quot;</span><span class="p">)</span>
            <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">num_domain</span><span class="p">)</span>
        <span class="p">]</span>

    <span class="k">def</span><span class="w"> </span><span class="nf">_grid_indices</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">rows</span><span class="p">,</span> <span class="n">dim</span><span class="p">):</span>
        <span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">range</span><span class="p">(</span><span class="n">dim</span><span class="p">)</span>
        <span class="n">x_grid</span><span class="p">,</span> <span class="n">y_grid</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">meshgrid</span><span class="p">(</span><span class="n">rows</span><span class="p">,</span> <span class="n">y</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">transpose</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">x_grid</span><span class="p">,</span> <span class="n">y_grid</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">),</span> <span class="p">[</span><span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>

    <span class="k">def</span><span class="w"> </span><span class="nf">call</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">inputs</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
        <span class="c1"># x: [B, D]; domain_index: [B]</span>
        <span class="n">x</span><span class="p">,</span> <span class="n">domain_index</span> <span class="o">=</span> <span class="n">inputs</span>
        <span class="n">domain_index</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cast</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">domain_index</span><span class="p">,</span> <span class="p">[</span><span class="o">-</span><span class="mi">1</span><span class="p">]),</span> <span class="n">tf</span><span class="o">.</span><span class="n">int32</span><span class="p">)</span>
        <span class="n">dim</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span class="o">-</span><span class="mi">1</span><span class="p">]</span>

        <span class="n">y</span> <span class="o">=</span> <span class="n">x</span>
        <span class="k">for</span> <span class="n">i</span><span class="p">,</span> <span class="n">bn</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">bn_list</span><span class="p">):</span>
            <span class="n">mask</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">equal</span><span class="p">(</span><span class="n">domain_index</span><span class="p">,</span> <span class="n">i</span><span class="p">)</span>
            <span class="k">def</span><span class="w"> </span><span class="nf">update</span><span class="p">():</span>
                <span class="n">xi</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">boolean_mask</span><span class="p">(</span><span class="n">x</span><span class="p">,</span> <span class="n">mask</span><span class="p">)</span>                    <span class="c1"># select the samples of this domain</span>
                <span class="n">yi</span> <span class="o">=</span> <span class="n">bn</span><span class="p">(</span><span class="n">xi</span><span class="p">,</span> <span class="n">training</span><span class="o">=</span><span class="n">training</span><span class="p">)</span>                   <span class="c1"># batch-normalize with this domain's statistics</span>
                <span class="n">rows</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">boolean_mask</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">range</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">x</span><span class="p">)[</span><span class="mi">0</span><span class="p">]),</span> <span class="n">mask</span><span class="p">)</span>
                <span class="n">grid</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">_grid_indices</span><span class="p">(</span><span class="n">rows</span><span class="p">,</span> <span class="n">dim</span><span class="p">)</span>             <span class="c1"># build the (row, col) coordinates to write back to</span>
                <span class="k">return</span> <span class="n">tf</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">tensor_scatter_nd_update</span><span class="p">(</span><span class="n">y</span><span class="p">,</span> <span class="n">grid</span><span class="p">,</span> <span class="n">yi</span><span class="p">),</span> <span class="n">tf</span><span class="o">.</span><span class="n">shape</span><span class="p">(</span><span class="n">y</span><span class="p">))</span>

            <span class="n">y</span> <span class="o">=</span> <span class="n">tf</span><span class="o">.</span><span class="n">cond</span><span class="p">(</span><span class="n">tf</span><span class="o">.</span><span class="n">reduce_any</span><span class="p">(</span><span class="n">mask</span><span class="p">),</span> <span class="n">update</span><span class="p">,</span> <span class="k">lambda</span><span class="p">:</span> <span class="n">y</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">y</span>
</pre></div>
</div>
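<p>To make the mask-and-scatter logic above concrete, here is a minimal pure-Python sketch of the training-time behavior: the rows of each domain are normalized with that domain's own batch mean and variance and written back to their original positions. It deliberately leaves out the learnable scale and shift and the moving averages that a real batch-normalization layer maintains. The function name <code>partitioned_normalize</code> and the <code>eps</code> default are illustrative assumptions, not part of the FunRec code.</p>

```python
import math

def partitioned_normalize(x, domain_index, num_domains, eps=1e-3):
    """Normalize each row of x with the batch statistics of its own domain.

    x: list of rows (each a list of floats), shape [B, D]
    domain_index: list of ints in [0, num_domains), length B
    Hypothetical helper for illustration only (no learnable gamma/beta).
    """
    y = [list(row) for row in x]  # start from a copy of the input, like `y = x` above
    dim = len(x[0]) if x else 0
    for d in range(num_domains):
        rows = [r for r, di in enumerate(domain_index) if di == d]
        if not rows:  # mirrors the tf.cond guard: skip domains absent from the batch
            continue
        for c in range(dim):
            vals = [x[r][c] for r in rows]
            mean = sum(vals) / len(vals)
            var = sum((v - mean) ** 2 for v in vals) / len(vals)  # biased variance
            for r in rows:
                y[r][c] = (x[r][c] - mean) / math.sqrt(var + eps)
    return y

# Two domains with very different scales; domain 2 has no samples in this batch.
batch = [[1.0, 10.0], [3.0, 30.0], [100.0, 0.0], [300.0, 0.0]]
domains = [0, 0, 1, 1]
normalized = partitioned_normalize(batch, domains, num_domains=3)
```

<p>Although the raw values of the two domains differ by two orders of magnitude, after per-domain normalization their rows land on a comparable scale, which is the motivation for keeping separate statistics per domain.</p>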
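<p><strong>Auxiliary network</strong></p>
<p>To further strengthen the influence of scenario features on the model output, STAR also builds a separate auxiliary network for the scenario.
The auxiliary network feeds the scenario features together with the other features into a shallow network to produce an auxiliary logit, which is added to the logit of the main network to obtain the final CTR prediction:</p>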
<div class="math notranslate nohighlight" id="equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-7">
<span class="eqno">(3.5.8)<a class="headerlink" href="#equation-chapter-2-ranking-5-multi-scenario-1-multi-tower-7" title="Permalink to this equation">¶</a></span>\[pCTR = Sigmoid(Logits_{main} + Logits_{aux})\]</div>
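<p>As a quick numeric illustration of Equation (3.5.8), the sketch below adds a main-network logit and an auxiliary logit before applying the sigmoid. The function names and the logit values are made up for illustration.</p>

```python
import math

def sigmoid(z):
    """Logistic function mapping a real-valued logit to (0, 1)."""
    return 1.0 / (1.0 + math.exp(-z))

def star_pctr(logit_main, logit_aux):
    """pCTR = Sigmoid(Logits_main + Logits_aux), as in Equation (3.5.8)."""
    return sigmoid(logit_main + logit_aux)

# Hypothetical logits for one sample: the auxiliary network shifts the main
# logit before the sigmoid, so its effect is additive in log-odds space.
p_main_only = star_pctr(-1.0, 0.0)
p_with_aux = star_pctr(-1.0, 0.5)
print(round(p_main_only, 4), round(p_with_aux, 4))
```

<p>Because the two logits are summed before the sigmoid, a positive auxiliary logit always raises the predicted CTR and a negative one always lowers it, regardless of the main logit.</p>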
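<p>The STAR model is implemented as follows:</p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="c1"># Core idea: main Star FCN + auxiliary DNN; the inputs of both are normalized per domain, and their outputs are fused</span>
<span class="c1"># 1) Input layer: build the inputs for all features and pick out the domain feature (used to partition samples by domain)</span>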
<span class="n">input_layer_dict</span> <span class="o">=</span> <span class="n">build_input_layer</span><span class="p">(</span><span class="n">feature_columns</span><span class="p">)</span>
<span class="n">domain_input</span> <span class="o">=</span> <span class="n">input_layer_dict</span><span class="p">[</span><span class="n">domain_feature_name</span><span class="p">]</span>

<span class="c1"># 2) Get the representations of the dnn group and the domain group</span>
<span class="n">dnn_inputs</span> <span class="o">=</span> <span class="n">concat_group_embedding</span><span class="p">(</span><span class="n">group_embedding_feature_dict</span><span class="p">,</span> <span class="s1">&#39;dnn&#39;</span><span class="p">)</span>          <span class="c1"># [B, dnn_dim]</span>
<span class="n">domain_embeddings</span> <span class="o">=</span> <span class="n">concat_group_embedding</span><span class="p">(</span><span class="n">group_embedding_feature_dict</span><span class="p">,</span> <span class="s1">&#39;domain&#39;</span><span class="p">)</span> <span class="c1"># [B, domain_dim]</span>

<span class="c1"># 3) Main branch (Star FCN): partition-normalize the dnn inputs per domain, then apply the star-topology FCN</span>
<span class="n">fcn_inputs</span> <span class="o">=</span> <span class="n">PartitionedNormalization</span><span class="p">(</span><span class="n">num_domain</span><span class="o">=</span><span class="n">num_domains</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;fcn_pn_layer&quot;</span><span class="p">)(</span>
    <span class="p">[</span><span class="n">dnn_inputs</span><span class="p">,</span> <span class="n">domain_input</span><span class="p">]</span>
<span class="p">)</span>
<span class="n">fcn_output</span> <span class="o">=</span> <span class="n">StarTopologyFCN</span><span class="p">(</span><span class="n">num_domains</span><span class="p">,</span> <span class="n">star_dnn_units</span><span class="p">,</span> <span class="n">star_fcn_activation</span><span class="p">,</span>
                                <span class="n">dropout</span><span class="p">,</span> <span class="n">l2_reg</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;star_fcn_layer&quot;</span><span class="p">)([</span><span class="n">fcn_inputs</span><span class="p">,</span> <span class="n">domain_input</span><span class="p">])</span>
<span class="n">fcn_logit</span> <span class="o">=</span> <span class="n">PredictLayer</span><span class="p">(</span><span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;fcn_logits&#39;</span><span class="p">)(</span><span class="n">fcn_output</span><span class="p">)</span>

<span class="c1"># 4) Auxiliary branch (Aux DNN): concatenate the domain embeddings with the dnn inputs and pass them through a lightweight DNN</span>
<span class="n">aux_inputs</span> <span class="o">=</span> <span class="n">concat_func</span><span class="p">([</span><span class="n">domain_embeddings</span><span class="p">,</span> <span class="n">dnn_inputs</span><span class="p">],</span> <span class="n">axis</span><span class="o">=-</span><span class="mi">1</span><span class="p">)</span>                 <span class="c1"># [B, domain_dim + dnn_dim]</span>
<span class="n">aux_inputs</span> <span class="o">=</span> <span class="n">PartitionedNormalization</span><span class="p">(</span><span class="n">num_domain</span><span class="o">=</span><span class="n">num_domains</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s2">&quot;aux_pn_layer&quot;</span><span class="p">)(</span>
    <span class="p">[</span><span class="n">aux_inputs</span><span class="p">,</span> <span class="n">domain_input</span><span class="p">]</span>
<span class="p">)</span>
<span class="n">aux_output</span> <span class="o">=</span> <span class="n">DNNs</span><span class="p">(</span><span class="n">aux_dnn_units</span><span class="p">,</span> <span class="n">dropout_rate</span><span class="o">=</span><span class="n">dropout</span><span class="p">)(</span><span class="n">aux_inputs</span><span class="p">)</span>
<span class="n">aux_logit</span> <span class="o">=</span> <span class="n">PredictLayer</span><span class="p">(</span><span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;aux_logits&#39;</span><span class="p">)(</span><span class="n">aux_output</span><span class="p">)</span>

<span class="c1"># 5) Fusion: add the logits of the main and auxiliary branches to get the final logit</span>
<span class="n">final_logits</span> <span class="o">=</span> <span class="n">add_tensor_func</span><span class="p">([</span><span class="n">fcn_logit</span><span class="p">,</span> <span class="n">aux_logit</span><span class="p">])</span>

<span class="n">final_prediction</span> <span class="o">=</span> <span class="n">PredictLayer</span><span class="p">(</span><span class="n">activation</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">name</span><span class="o">=</span><span class="s1">&#39;final_prediction&#39;</span><span class="p">)(</span><span class="n">final_logits</span><span class="p">)</span>
</pre></div>
</div>
<p><strong>Hands-on practice</strong></p>
<div class="highlight-python notranslate"><div class="highlight"><pre><span></span><span class="n">run_experiment</span><span class="p">(</span><span class="s1">&#39;star&#39;</span><span class="p">)</span>
</pre></div>
</div>
<div class="output highlight-default notranslate"><div class="highlight"><pre><span></span><span class="o">+--------+--------+------------+</span>
<span class="o">|</span>    <span class="n">auc</span> <span class="o">|</span>   <span class="n">gauc</span> <span class="o">|</span>   <span class="n">val_user</span> <span class="o">|</span>
<span class="o">+========+========+============+</span>
<span class="o">|</span> <span class="mf">0.6648</span> <span class="o">|</span> <span class="mf">0.6244</span> <span class="o">|</span>        <span class="mi">693</span> <span class="o">|</span>
<span class="o">+--------+--------+------------+</span>
</pre></div>
</div>
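<p>The table reports both AUC and GAUC. The evaluation code behind <code>run_experiment</code> is not shown in this section, so the following is only a sketch under a commonly used definition, assumed here: GAUC is the sample-count-weighted average of per-user AUC, taken over users whose validation samples contain both a positive and a negative label (which is presumably what <code>val_user</code> counts). The helper names and the toy data are hypothetical.</p>

```python
from collections import defaultdict

def auc(labels, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        return None  # undefined when only one class is present
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def gauc(user_ids, labels, scores):
    """Sample-count-weighted average of per-user AUC; returns (gauc, valid_user_count)."""
    groups = defaultdict(lambda: ([], []))
    for u, y, s in zip(user_ids, labels, scores):
        groups[u][0].append(y)
        groups[u][1].append(s)
    weighted_sum, total_weight, valid_users = 0.0, 0, 0
    for user_labels, user_scores in groups.values():
        user_auc = auc(user_labels, user_scores)
        if user_auc is None:
            continue  # skip users whose samples are all positive or all negative
        weighted_sum += user_auc * len(user_labels)
        total_weight += len(user_labels)
        valid_users += 1
    value = weighted_sum / total_weight if total_weight else float("nan")
    return value, valid_users

# Toy data: user "c" has only a positive sample and is therefore skipped.
users = ["a", "a", "a", "b", "b", "c"]
labels = [1, 0, 0, 1, 0, 1]
scores = [0.9, 0.3, 0.95, 0.2, 0.6, 0.7]
value, n_valid = gauc(users, labels, scores)
```

<p>Under this definition GAUC only measures how well items are ranked within each user's own samples, which is why it can differ noticeably from the global AUC.</p>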
</section>
</section>


        </div>

      <div class="clearer"></div>
    </div>
        
        </main>
    </div>
  </body>
</html>